{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MulAndDiv\n",
    "\n",
    "参考：`tvm/src/relay/quantize/realize.cc`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```c++\n",
    "/* calculate `data * s1 / s2`, use shift if possible */\n",
    "inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,\n",
    "                      const Array<IndexExpr>& data_shape) {\n",
    "  const QConfig& cfg = QConfig::Current();\n",
    "  // here we assume the dtype of data is dtype activation\n",
    "  if (s1 == s2) return data;\n",
    "\n",
    "  float factor = s1 / s2;\n",
    "  float shift_factor = std::log2(factor);\n",
    "  ICHECK_GT(shift_factor, 0);\n",
    "  if (static_cast<int>(shift_factor) == shift_factor) {\n",
    "    return LeftShift(data, MakeConstantScalar(dtype, static_cast<int>(shift_factor)));\n",
    "  } else if (static_cast<int>(factor) == factor) {\n",
    "    return Multiply(data, MakeConstantScalar(dtype, factor));\n",
    "  } else {\n",
    "    if (cfg->rounding == \"UPWARD\") {\n",
    "      auto [fixed_point_multiplier, shift] = qnn::GetFixedPointMultiplierShift(factor);\n",
    "      data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);\n",
    "    } else {\n",
    "      data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape);\n",
    "    }\n",
    "\n",
    "    return Cast(data, dtype);\n",
    "  }\n",
    "}\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码定义了一个名为`MulAndDiv`的内联函数，用于计算 `data * s1 / s2`。如果可能的话，它会使用位移运算来优化计算过程。\n",
    "\n",
    "函数接收5个参数：\n",
    "- `data`：需要进行计算的数据；\n",
    "- `s1` 和 `s2`：两个浮点数，用于计算 `data * s1 / s2`；\n",
    "- `dtype`：数据类型；\n",
    "- `data_shape`：数据的形状。\n",
    "\n",
    "函数首先获取当前的量化配置（`QConfig::Current()`），然后判断 `s1` 和 `s2` 是否相等，如果相等则直接返回 `data`。\n",
    "\n",
    "接下来，计算 `factor = s1 / s2`，并计算 `shift_factor = std::log2(factor)`。如果 `shift_factor` 大于 0 且为整数，则对 `data` 进行左移运算。如果 `factor` 为整数，则对 `data` 进行乘法运算。否则，根据量化配置中的舍入方式（`cfg->rounding`）进行定点乘法运算，并将结果转换为指定的数据类型。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312x",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
